package com.ouyunc.oauth2.exception;

import com.ouyunc.oauth2.exception.handler.Auth2GlobalFilterExceptionHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.DefaultThrowableAnalyzer;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InsufficientScopeException;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.provider.error.WebResponseExceptionTranslator;
import org.springframework.security.web.util.ThrowableAnalyzer;
import org.springframework.stereotype.Component;
import org.springframework.web.HttpRequestMethodNotSupportedException;

import java.io.IOException;

/**
 * Custom OAuth2 exception translator, modelled on Spring Security OAuth2's
 * {@code DefaultWebResponseExceptionTranslator}.
 *
 * <p>Walks the cause chain of an exception raised while serving an authorization-server request,
 * maps it to an {@link OAuth2Exception} subtype that carries the HTTP status to return, and then
 * wraps the result in an {@link IOAuth2Exception} so that every error response has the project's
 * unified body format.
 *
 * @Author fangzhenxun
 * @Date 2020/4/22 17:29
 **/
@Component
public class IOAuth2WebResponseExceptionTranslator implements WebResponseExceptionTranslator<OAuth2Exception> {

    private static final Logger log = LoggerFactory.getLogger(IOAuth2WebResponseExceptionTranslator.class);

    /**
     * Analyzer used to unwrap the cause chain of the exception being translated.
     **/
    private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();

    /**
     * Replaces the analyzer used to walk the cause chain of translated exceptions.
     *
     * @param throwableAnalyzer the analyzer to use
     */
    public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
        this.throwableAnalyzer = throwableAnalyzer;
    }

    /**
     * Translates an exception raised by the authorization server (e.g. a 400 from the token
     * endpoint) into a response entity with the project's unified error format.
     *
     * <p>The cause chain is searched in priority order: {@link InvalidGrantException},
     * any other {@link OAuth2Exception}, {@link AuthenticationException},
     * {@link AccessDeniedException}, {@link HttpRequestMethodNotSupportedException},
     * {@link IllegalArgumentException}. Anything else becomes a 500.
     *
     * @param e the exception to translate
     * @return the error response entity
     * @Author fangzhenxun
     * @Date 2020/4/22 17:31
     **/
    @Override
    public ResponseEntity<OAuth2Exception> translate(Exception e) throws Exception {
        // Try to extract a Spring Security / OAuth2 exception from the stack trace
        Throwable[] causeChain = throwableAnalyzer.determineCauseChain(e);

        // InvalidGrantException is itself an OAuth2Exception, so it must be checked before the generic case
        Exception ase = (InvalidGrantException) throwableAnalyzer.getFirstThrowableOfType(InvalidGrantException.class, causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new InvalidUsernamePasswordException(e.getMessage(), e));
        }

        // Any other OAuth2Exception already carries the correct HTTP status and error code,
        // so pass it through unchanged (re-wrapping it in a plain OAuth2Exception would turn
        // every status, e.g. 401 invalid_client or 403 insufficient_scope, into 400)
        ase = (OAuth2Exception) throwableAnalyzer.getFirstThrowableOfType(OAuth2Exception.class, causeChain);
        if (ase != null) {
            return handleOAuth2Exception((OAuth2Exception) ase);
        }

        // Authentication failure -> 401
        ase = (AuthenticationException) throwableAnalyzer.getFirstThrowableOfType(AuthenticationException.class, causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new UnauthorizedException(e.getMessage(), e));
        }

        // Insufficient permissions -> 403
        ase = (AccessDeniedException) throwableAnalyzer.getFirstThrowableOfType(AccessDeniedException.class, causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new ForbiddenException(ase.getMessage(), ase));
        }

        // HTTP method not allowed -> 405
        ase = (HttpRequestMethodNotSupportedException) throwableAnalyzer.getFirstThrowableOfType(HttpRequestMethodNotSupportedException.class, causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new MethodNotAllowed(ase.getMessage(), ase));
        }

        // Illegal argument -> 400
        ase = (IllegalArgumentException) throwableAnalyzer.getFirstThrowableOfType(IllegalArgumentException.class, causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new IllegalArgument(ase.getMessage(), ase));
        }

        // None of the above: internal server error
        return handleOAuth2Exception(new ServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase(), e));
    }

    /**
     * Builds the HTTP response for an already-classified OAuth2 exception.
     *
     * <p>Sets no-cache headers and, for 401 responses or {@link InsufficientScopeException}, a
     * Bearer {@code WWW-Authenticate} challenge. The body is an {@link IOAuth2Exception} carrying
     * the message (or {@code "Bad Request"} when there is none) and the HTTP status code.
     *
     * @param e the classified exception; its {@code getHttpErrorCode()} becomes the response status
     * @return the response entity to send to the client
     * @Author fangzhenxun
     * @Date 2020/4/23 10:18
     **/
    private ResponseEntity<OAuth2Exception> handleOAuth2Exception(OAuth2Exception e) {

        int code = e.getHttpErrorCode();
        // Response headers: error responses must never be cached
        HttpHeaders headers = new HttpHeaders();
        headers.set("Cache-Control", "no-store");
        headers.set("Pragma", "no-cache");
        // Unauthenticated or insufficient scope: add a Bearer challenge.
        // NOTE(review): getSummary() embeds getOAuth2ErrorCode(); for UnauthorizedException that is
        // non-ASCII text, which may not survive as an HTTP header value — verify with the container.
        if (code == HttpStatus.UNAUTHORIZED.value() || (e instanceof InsufficientScopeException)) {
            headers.set("WWW-Authenticate", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, e.getSummary()));
        }
        IOAuth2Exception exception = new IOAuth2Exception(e.getMessage() == null ? "Bad Request" : e.getMessage());
        exception.setCode(code);
        if (code >= HttpStatus.INTERNAL_SERVER_ERROR.value()) {
            // Server-side failures: keep the stack trace (the original exception is the cause of e)
            log.error("oauth2 异常：{}", exception.getMessage(), e);
        } else {
            log.error("oauth2 异常：{}", exception.getMessage());
        }
        return new ResponseEntity<>(exception, headers, HttpStatus.valueOf(code));
    }

    /** Access denied: HTTP 403. */
    @SuppressWarnings("serial")
    private static class ForbiddenException extends OAuth2Exception {

        public ForbiddenException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "权限不足";
        }

        @Override
        public int getHttpErrorCode() {
            return 403;
        }

    }

    /** Unclassified failure: HTTP 500. */
    @SuppressWarnings("serial")
    private static class ServerErrorException extends OAuth2Exception {

        public ServerErrorException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "服务器内部错误";
        }

        @Override
        public int getHttpErrorCode() {
            return 500;
        }

    }

    /** Invalid grant (e.g. wrong username/password): HTTP 401. */
    @SuppressWarnings("serial")
    private static class InvalidUsernamePasswordException extends OAuth2Exception {

        public InvalidUsernamePasswordException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "Bad credentials";
        }

        @Override
        public int getHttpErrorCode() {
            return 401;
        }

    }

    /** Authentication failure: HTTP 401. */
    @SuppressWarnings("serial")
    private static class UnauthorizedException extends OAuth2Exception {

        public UnauthorizedException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "未授权";
        }

        @Override
        public int getHttpErrorCode() {
            return 401;
        }

    }

    /** Unsupported HTTP method: HTTP 405. */
    @SuppressWarnings("serial")
    private static class MethodNotAllowed extends OAuth2Exception {
        public MethodNotAllowed(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "请求方式不允许";
        }

        @Override
        public int getHttpErrorCode() {
            return 405;
        }

    }

    /** Illegal argument: HTTP 400. */
    @SuppressWarnings("serial")
    private static class IllegalArgument extends OAuth2Exception {
        public IllegalArgument(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "非法参数: ";
        }

        @Override
        public int getHttpErrorCode() {
            return 400;
        }

    }

}
